from typing import Any, Dict, Tuple, List, Optional, Type, Callable
from functools import partial

import gym
import haiku as hk
import jax
import jax.numpy as jnp
import jax.random as jrng
import optax
from sb3_jax.common.policies import BasePolicy
from sb3_jax.common.utils import print_b
from sb3_jax.common.type_aliases import Schedule
from sb3_jax.du.policies import DiffusionBetaScheduler

from diffgro.common.models.helpers import MLP
from diffgro.common.models.diffusion import UNetDiffusion as TemporalUNet, Diffusion
from diffgro.common.models.utils import fn_foward_noise
from diffgro.experiments.skill_diffuser.networks import SkillPredictor, VectorQuantizer


class SkillDiffuserPolicy(BasePolicy):
    """Policy that plans future observations and decodes them into actions.

    Two components are created in :meth:`_build`:

    * ``self.plan`` (a ``PlannerComponent``) turns the observation history and
      a language embedding into a trajectory of predicted observations.
    * ``self.inv`` (an ``InverseDynamicsComponent``) maps a
      ``(current observation, predicted next observation)`` pair to an action.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        seed: int,
        # embedding
        horizon: int,
        lang_dim: int,
        skill_dim: int,
        activation_fn: str,
        lr_schedule: Dict[str, Schedule],
        component_kwargs: Dict[str, Any],
    ):
        """Store the configuration and build both components.

        Args:
            observation_space: Space whose ``shape[0]`` gives the observation size.
            action_space: Space whose ``shape[0]`` gives the action size.
            seed: Seed forwarded to the base policy and to both components.
            horizon: Period (in steps) used by :meth:`_predict` to decide when
                to re-plan and to index into the cached plan.
            lang_dim: Size of the language embedding, forwarded to the planner.
            skill_dim: Size of the skill embedding, forwarded to the planner.
            activation_fn: Activation name forwarded to both components.
            lr_schedule: Learning-rate schedules keyed by ``"skill_prd"``,
                ``"cond_diff"`` and ``"inv"``.
            component_kwargs: Extra constructor kwargs keyed by ``"plan"`` and
                ``"inv"``.
        """
        super().__init__(
            observation_space,
            action_space,
            seed=seed,
        )
        self.obs_dim = observation_space.shape[0]
        self.act_dim = action_space.shape[0]

        self.horizon = horizon
        self.lang_dim = lang_dim
        self.skill_dim = skill_dim

        self.activation_fn = activation_fn

        self.lr_schedule = lr_schedule
        self.component_kwargs = component_kwargs

        self.offset = 0
        # Cached planner output; filled lazily by `_predict`. Initialised here
        # so the first prediction can never hit a missing attribute.
        self.horizon_plan: Optional[jax.Array] = None
        self._build(lr_schedule)

    def _build(self, lr_schedule: Dict[str, Schedule]) -> None:
        """Instantiate the planner and inverse-dynamics components and their optimizers.

        Args:
            lr_schedule: Learning-rate schedules keyed by ``"skill_prd"``,
                ``"cond_diff"`` and ``"inv"``.
        """
        # make skill_prd
        self.plan = PlannerComponent(
            self.obs_dim,
            self.lang_dim,
            self.skill_dim,
            self.activation_fn,
            **self.component_kwargs["plan"],
            seed=self.seed,
        )

        # Each optimizer only transforms the parameters whose Haiku module
        # name starts with the given prefix.
        # NOTE(review): `optax.masked` passes updates for non-matching leaves
        # through unchanged, so any trainable parameter outside these two
        # prefixes would receive raw gradients — confirm none exist.
        self.plan.optim = optax.chain(
            optax.masked(
                self.optimizer_class(learning_rate=lr_schedule["skill_prd"]),
                partial(
                    hk.data_structures.map,
                    lambda mname, name, val: mname.startswith("skill_predictor"),
                ),
            ),
            optax.masked(
                self.optimizer_class(learning_rate=lr_schedule["cond_diff"]),
                partial(
                    hk.data_structures.map,
                    lambda mname, name, val: mname.startswith("u_net_diffusion"),
                ),
            ),
        )
        self.plan.optim_state = self.plan.optim.init(self.plan.params)

        # make inv
        self.inv = InverseDynamicsComponent(
            self.obs_dim,
            self.act_dim,
            self.activation_fn,
            **self.component_kwargs["inv"],
            seed=self.seed,
        )
        self.inv.optim = self.optimizer_class(learning_rate=lr_schedule["inv"])
        self.inv.optim_state = self.inv.optim.init(self.inv.params)

    def _predict(
        self,
        obs_stack: jax.Array,  # (B, seq_len, obs_dim)
        lang: Optional[jax.Array],  # (B, lang_dim)
    ) -> jax.Array:
        """Predict the action for the most recent observation in ``obs_stack``.

        A new plan is generated when ``(seq_len - 2) % horizon == 0`` or when
        no plan has been cached yet; otherwise the cached plan is reused.

        Args:
            obs_stack: Observation history; the last entry along axis 1 is
                treated as the current observation.
            lang: Language embedding. Although annotated ``Optional``, it must
                not be ``None`` when planning happens because it is reshaped
                to ``(1, 1, -1)`` (i.e. a single embedding is assumed).

        Returns:
            The inverse-dynamics output flattened to one dimension.
        """
        timestep_in_horizon = obs_stack.shape[1] - 1

        # Plan horizon: re-plan on the periodic schedule, and also whenever
        # there is no cached plan to fall back on.
        needs_plan = (
            self.horizon_plan is None
            or (timestep_in_horizon - 1) % self.horizon == 0
        )
        if needs_plan:
            self.horizon_plan = self.plan(obs_stack, lang.reshape(1, 1, -1))

        obs = obs_stack[:, timestep_in_horizon]
        # NOTE(review): the plan is indexed by the absolute step modulo
        # `horizon`, not by the offset since the last re-plan — confirm this
        # matches how the planner lays out its trajectory.
        next_obs_pred = self.horizon_plan[:, (timestep_in_horizon + 1) % self.horizon]

        act_pred = self.inv(obs, next_obs_pred).reshape(-1)
        return act_pred


#########################################################################################################


class Component:
    """Base class giving each component a seed and its own Haiku RNG stream."""

    def __init__(self, seed: int):
        """Create the PRNG key sequence from ``seed`` and remember the seed."""
        self.rng = hk.PRNGSequence(seed)
        self.seed = seed


class PlannerComponent(Component):
    """Skill-conditioned diffusion planner over observation trajectories.

    The Haiku function built in :meth:`_build` chains three modules:
    a ``SkillPredictor`` (observations + language -> skill), a
    ``VectorQuantizer`` (skill -> quantized skill and a VQ loss), and a
    ``Diffusion`` wrapper around a temporal U-Net that is conditioned on the
    quantized skill through the ``"ctx"`` batch key.
    """

    def __init__(
        self,
        obs_dim: int,
        lang_dim: int,
        skill_dim: int,
        activation_fn: str,
        plan_horizon: int,
        n_denoise: int,
        beta_scheduler: str,
        predict_epsilon: bool,
        lmbda: float,
        skill_prd_kwargs: Dict[str, Any],
        vec_quant_kwargs: Dict[str, Any],
        temp_unet_kwargs: Dict[str, Any],
        cond_diff_kwargs: Dict[str, Any],
        seed: int,
    ):
        """Store hyper-parameters, build the noise schedule and init the network.

        Args:
            obs_dim: Size of one observation; also the planner's output size.
            lang_dim: Size of the language embedding.
            skill_dim: Size passed to the skill predictor and vector quantizer.
            activation_fn: Activation name forwarded to the U-Net.
            plan_horizon: Number of steps in a planned trajectory.
            n_denoise: Number of diffusion steps.
            beta_scheduler: Scheduler name given to ``DiffusionBetaScheduler``.
            predict_epsilon: If true the network output is treated as the
                noise; otherwise as the clean trajectory.
            lmbda: Weight of the diffusion loss in the total loss.
            skill_prd_kwargs: Extra kwargs for ``SkillPredictor``.
            vec_quant_kwargs: Extra kwargs for ``VectorQuantizer``.
            temp_unet_kwargs: Extra kwargs for the temporal U-Net.
            cond_diff_kwargs: Extra kwargs for ``Diffusion``.
            seed: Seed for this component's RNG stream.
        """
        super().__init__(seed)

        self.obs_dim = obs_dim
        self.lang_dim = lang_dim
        self.skill_dim = skill_dim
        self.plan_horizon = plan_horizon
        self.activation_fn = activation_fn
        self.n_denoise = n_denoise
        self.predict_epsilon = predict_epsilon
        self.lmbda = lmbda

        self.ddpm_dict = DiffusionBetaScheduler(
            None, None, n_denoise, beta_scheduler
        ).schedule()

        self.skill_prd_kwargs = skill_prd_kwargs
        self.vec_quant_kwargs = vec_quant_kwargs
        self.temp_unet_kwargs = temp_unet_kwargs
        self.cond_diff_kwargs = cond_diff_kwargs

        self._build()

    def _build_skill_prd(self) -> hk.Module:
        """Create the skill predictor module (call inside a Haiku transform)."""
        return SkillPredictor(self.skill_dim, **self.skill_prd_kwargs)

    def _build_vec_quant(self) -> hk.Module:
        """Create the vector quantizer module (call inside a Haiku transform)."""
        return VectorQuantizer(self.skill_dim, **self.vec_quant_kwargs)

    def _build_cond_diff(self) -> hk.Module:
        """Create the diffusion wrapper around a temporal U-Net conditioned on ``"ctx"``."""
        unet = TemporalUNet(
            out_dim=self.obs_dim,
            horizon=self.plan_horizon,
            batch_keys=["ctx"],
            activation_fn=self.activation_fn,
            **self.temp_unet_kwargs,
        )
        return Diffusion(
            diffusion=unet,
            n_denoise=self.n_denoise,
            ddpm_dict=self.ddpm_dict,
            predict_epsilon=self.predict_epsilon,  # predict noise
            **self.cond_diff_kwargs,
        )

    def _build(self) -> None:
        """Transform the forward function and initialise ``params`` and ``state``.

        Initialisation uses random dummy inputs of shapes
        ``(1, 1, obs_dim)``, ``(1, 1, lang_dim)`` and
        ``(1, plan_horizon, obs_dim)``, plus a dummy timestep ``[[1.0]]``.
        """
        dummy_obs = jrng.normal(next(self.rng), shape=(1, 1, self.obs_dim))
        dummy_lang = jrng.normal(next(self.rng), shape=(1, 1, self.lang_dim))

        dummy_x_t = jrng.normal(
            next(self.rng), shape=(1, self.plan_horizon, self.obs_dim)
        )
        dummy_t = jnp.array([[1.0]])

        def fn(
            obs: jax.Array,
            lang: jax.Array,
            x_t: jax.Array,
            t: jax.Array,
            denoise: bool,
            deterministic: bool,
            is_training: bool,
        ):
            skill_prd = self._build_skill_prd()
            vec_quant = self._build_vec_quant()
            cond_diff = self._build_cond_diff()

            skill_pred = skill_prd(obs, lang)
            vq_loss, skill_quant = vec_quant(skill_pred, is_training)
            batch_dict = {"ctx": skill_quant}
            return vq_loss, cond_diff(x_t, batch_dict, t, denoise, deterministic)

        # `transform_with_state` yields an (init, apply) pair.
        init_fn, self.pi = hk.transform_with_state(fn)
        self.params, self.state = init_fn(
            next(self.rng),
            dummy_obs,
            dummy_lang,
            dummy_x_t,
            dummy_t,
            denoise=False,
            deterministic=False,
            is_training=True,
        )

    # `self` and the three boolean flags are static: they select Python-level
    # branches and must not be traced.
    @partial(jax.jit, static_argnums=(0, 5, 6, 7))
    def _pi(
        self,
        obs: jax.Array,
        lang: jax.Array,
        x_t: jax.Array,
        t: jax.Array,
        denoise: bool,
        deterministic: bool,
        is_training: bool,
        params: hk.Params,
        state: hk.State,
        rng=None,
    ) -> Tuple[Tuple[jax.Array, Any], hk.State]:
        """Jitted apply of the transformed network.

        Returns:
            ``((vq_loss, diffusion_output), new_state)``.
        """
        return self.pi(
            params, state, rng, obs, lang, x_t, t, denoise, deterministic, is_training
        )

    def _predict(
        self,
        obs: jax.Array,
        lang: jax.Array,
        x_t: jax.Array,
        t: int,
        deterministic: bool = False,
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
        """Run the network once at diffusion step ``t`` in inference mode.

        The scalar ``t`` is broadcast to one timestep per batch row. The VQ
        loss and the updated Haiku state are discarded.

        Returns:
            The two elements of the diffusion module's output, ``(eps, info)``.
        """
        ts = jnp.full((x_t.shape[0], 1), t)
        # `[0]` selects the network output (dropping the new state) and `[1]`
        # selects the diffusion output (dropping the VQ loss).
        eps, info = self._pi(
            obs,
            lang,
            x_t,
            ts,
            False,
            deterministic,
            False,
            self.params,
            self.state,
            next(self.rng),
        )[0][1]

        return eps, info

    @partial(jax.jit, static_argnums=(0, 4))
    def _sample(
        self,
        x_t: jax.Array,
        eps: jax.Array,
        t: int,
        deterministic: bool,
        rng=None,
    ) -> jax.Array:
        """Take one reverse-diffusion step from ``x_t`` at step ``t``.

        Args:
            x_t: Current noisy trajectory.
            eps: Network prediction for this step (noise if
                ``predict_epsilon`` is true, otherwise the clean trajectory).
            t: Diffusion step used to index the schedule tables.
            deterministic: If true no noise is added.
            rng: PRNG key used to draw the noise when not deterministic.

        Returns:
            The trajectory after the step, same shape as ``x_t``.
        """
        # Draw noise with the trajectory's own shape. The previous code read
        # `self.horizon` / `self.emb_dim`, which this class never defines, so
        # stochastic sampling raised AttributeError.
        noise = jrng.normal(rng, shape=x_t.shape) if not deterministic else 0.0

        if self.predict_epsilon:
            x_t = (
                self.ddpm_dict.oneover_sqrta[t]
                * (x_t - self.ddpm_dict.ma_over_sqrtmab_inv[t] * eps)
                + self.ddpm_dict.sqrt_beta_t[t] * noise
            )
        else:
            x_t = (
                self.ddpm_dict.posterior_mean_coef1[t] * eps
                + self.ddpm_dict.posterior_mean_coef2[t] * x_t
                + jnp.exp(0.5 * self.ddpm_dict.posterior_log_beta[t]) * noise
            )
        return x_t

    def __call__(
        self,
        obs: jax.Array,
        lang: jax.Array,
        deterministic: bool = True,
    ) -> jax.Array:
        """Generate a planned observation trajectory by iterative denoising.

        Starts from Gaussian noise of shape
        ``(batch, plan_horizon, obs_dim)`` and runs steps
        ``n_denoise, ..., 1``. Before every network call the first step of
        the trajectory is overwritten with the latest observation
        ``obs[:, -1]``; this clamp is not re-applied after the final sample.

        Args:
            obs: Observation history; axis 1 is time.
            lang: Language embedding passed to the skill predictor.
            deterministic: Forwarded to the network and to the sampling step.

        Returns:
            The denoised trajectory.
        """
        batch_size = obs.shape[0]

        x_t = jrng.normal(
            next(self.rng), shape=(batch_size, self.plan_horizon, self.obs_dim)
        )

        # `cond` holds the known first step; `mask` marks where it applies.
        cond = jnp.concatenate(
            [
                obs[:, -1][:, None],
                jnp.zeros((batch_size, self.plan_horizon - 1, self.obs_dim)),
            ],
            axis=1,
        )
        mask = jnp.concatenate(
            [
                jnp.ones((batch_size, 1, self.obs_dim)),
                jnp.zeros((batch_size, self.plan_horizon - 1, self.obs_dim)),
            ],
            axis=1,
        )
        for t in range(self.n_denoise, 0, -1):
            x_t = (1 - mask) * x_t + mask * cond
            eps, _ = self._predict(obs, lang, x_t, t, deterministic)
            x_t = self._sample(x_t, eps, t, deterministic, next(self.rng))
        return x_t

    def _load_jax_params(self, params: Dict[str, hk.Params]) -> None:
        """Replace ``self.params`` with ``params["plan_params"]``.

        Only the parameters are replaced; ``self.state`` is left untouched.
        """
        print_b("[skill diffuser/planner]: loading params")
        self.params = params["plan_params"]

    @partial(jax.jit, static_argnums=0)
    def _compute_loss(
        self,
        params: hk.Params,
        state: hk.State,
        obs_skill_prd: jax.Array,
        obs_cond_diff: jax.Array,
        lang: jax.Array,
        rng=None,
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
        """Compute the combined VQ + diffusion training loss.

        Args:
            params: Network parameters to evaluate.
            state: Haiku state passed to the network.
            obs_skill_prd: Observations fed to the skill predictor.
            obs_cond_diff: Clean trajectory that is noised and used as the
                diffusion target.
            lang: Language embedding.
            rng: PRNG key; split for timestep sampling, noise and the network.

        Returns:
            ``(loss, info)`` where ``loss = vq_loss + lmbda * diff_loss`` and
            ``info`` holds ``"total_loss"``, ``"vq_loss"``, ``"diff_loss"``
            and the updated Haiku ``"state"``.
        """
        bz = obs_skill_prd.shape[0]
        # The split count stays at 4 (last key unused) so the derived keys,
        # and therefore training randomness, are unchanged.
        rng_t, rng_n, rng_p, _ = jax.random.split(rng, num=4)

        # input: one diffusion step per sample, drawn from [1, n_denoise]
        ts = jax.random.randint(rng_t, (bz, 1), minval=1, maxval=self.n_denoise + 1)

        sqrtab = self.ddpm_dict.sqrtab[ts].reshape(bz, 1, 1)
        sqrtmab = self.ddpm_dict.sqrtmab[ts].reshape(bz, 1, 1)

        x_t, noise = fn_foward_noise(obs_cond_diff, sqrtab, sqrtmab, rng_n)

        (vq_loss, (noise_pred, _)), new_state = self._pi(
            obs_skill_prd, lang, x_t, ts, False, False, True, params, state, rng_p
        )

        # loss
        if self.predict_epsilon:  # prediction of noise
            diff_loss = jnp.mean(jnp.square(noise_pred - noise))
        else:  # prediction of the clean (un-noised) trajectory
            diff_loss = jnp.mean(jnp.square(noise_pred - obs_cond_diff))

        loss = vq_loss + self.lmbda * diff_loss
        return loss, {
            "total_loss": loss,
            "vq_loss": vq_loss,
            "diff_loss": diff_loss,
            "state": new_state,
        }


class InverseDynamicsComponent(Component):
    """MLP that predicts an action from an observation and the next observation."""

    def __init__(
        self,
        obs_dim: int,
        act_dim: int,
        activation_fn: str,
        hid_dim: int,
        net_arch: List[int],
        seed: int,
    ):
        """Store hyper-parameters and initialise the network.

        Args:
            obs_dim: Size of one observation.
            act_dim: Size of one action (the MLP's output size).
            activation_fn: Activation name forwarded to the MLP.
            hid_dim: Passed to the MLP as ``emb_dim``.
            net_arch: Passed to the MLP as ``net_arch``.
            seed: Seed for this component's RNG stream.
        """
        super().__init__(seed)

        self.obs_dim = obs_dim
        self.act_dim = act_dim
        self.activation_fn = activation_fn
        self.hid_dim = hid_dim
        self.net_arch = net_arch

        self._build()

    def _build_inv_dyna(self) -> hk.Module:
        """Create the MLP that reads the ``"obs"`` and ``"next_obs"`` batch keys."""
        inv_dyna = MLP(
            emb_dim=self.hid_dim,
            out_dim=self.act_dim,
            net_arch=self.net_arch,
            batch_keys=["obs", "next_obs"],
            activation_fn=self.activation_fn,
        )
        return inv_dyna

    def _build(self) -> None:
        """Transform the forward function and initialise ``self.params``.

        Initialisation uses two random dummy inputs of shape ``(1, obs_dim)``.
        """
        dummy_obs = jrng.normal(next(self.rng), shape=(1, self.obs_dim))
        dummy_next_obs = jrng.normal(next(self.rng), shape=(1, self.obs_dim))

        def fn(obs: jax.Array, next_obs: jax.Array):
            inv_dyna = self._build_inv_dyna()

            batch_dict = {"obs": obs, "next_obs": next_obs}
            return inv_dyna(batch_dict)

        # `hk.transform` yields an (init, apply) pair.
        init_fn, self.pi = hk.transform(fn)
        self.params = init_fn(
            next(self.rng),
            dummy_obs,
            dummy_next_obs,
        )

    # `self` is static so the bound apply function is not traced.
    @partial(jax.jit, static_argnums=0)
    def _pi(
        self,
        obs: jax.Array,
        next_obs: jax.Array,
        params: hk.Params,
        rng=None,
    ) -> jax.Array:
        """Jitted apply of the MLP with the given ``params``; returns its raw output."""
        return self.pi(params, rng, obs, next_obs)

    def __call__(
        self,
        obs: jax.Array,
        next_obs: jax.Array,
    ) -> jax.Array:
        """Predict the action for ``(obs, next_obs)`` using the stored parameters.

        A fresh key is drawn from this component's RNG stream on every call.
        """
        return self._pi(obs, next_obs, self.params, next(self.rng))

    def _load_jax_params(self, params: Dict[str, hk.Params]) -> None:
        """Replace ``self.params`` with ``params["inv_params"]``."""
        print_b("[skill diffuser/inverse dynamics]: loading params")
        self.params = params["inv_params"]

    @partial(jax.jit, static_argnums=0)
    def _compute_loss(
        self,
        params: hk.Params,
        obs: jax.Array,
        next_obs: jax.Array,
        act: jax.Array,
        rng=None,
    ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
        """Mean-squared error between predicted and reference actions.

        Returns:
            ``(loss, {"inv_loss": loss})``.
        """
        act_pred = self._pi(obs, next_obs, params, rng)
        loss = jnp.mean(jnp.square(act_pred - act))
        return loss, {"inv_loss": loss}
